In this case study, we train the model with $16 \times 16$ MNIST images and evaluate its performance.
import numpy as np
import pandas as pd
import scipy.ndimage
from tqdm import tqdm
from joblib import Parallel, delayed
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
import tensorflow as tf
pio.renderers.default = "notebook"
from src.imgen import ImageGenerator
from src.utils import tqdm_joblib, plot_dist
# Experiment configuration.
NUM_CPUS = 12            # joblib worker processes for parallel training
NUM_IMAGES = 36          # how many MNIST images to sample and train on
NUM_EPOCHS = 300         # training epochs per image
NUM_QUBITS = 8           # 2**8 = 256 basis states, one per pixel of a 16x16 image
NUM_LAYERS = 2           # circuit layers in the ImageGenerator ansatz
EPOCH_SAMPLE_SIZE = 10**4  # samples drawn per epoch
BATCH_SAMPLE_SIZE = 10**3  # samples drawn per batch
We use TensorFlow to directly load MNIST images. Each image has dimensions $28 \times 28$.
# Load the raw MNIST train/test splits bundled with Keras.
(x_train_raw, y_train), (x_test_raw, y_test) = tf.keras.datasets.mnist.load_data()
# Notebook display: confirm the native 28x28 image size.
x_train_raw[0].shape
(28, 28)
In order to fit the images onto $8$ qubits, we downscale them to $16 \times 16$ with cubic interpolation.
# Downscale every 28x28 image to 16x16 with cubic spline interpolation so a
# flattened image fits onto NUM_QUBITS = 8 qubits (2**8 = 256 = 16*16 pixels).
# Use the exact ratio 16/28 instead of the hard-coded approximation 0.58,
# which only reached 16 pixels through scipy's internal rounding.
ZOOM_FACTOR = 16 / 28
x_train = np.array([scipy.ndimage.zoom(x, ZOOM_FACTOR, order=3) for x in x_train_raw])
x_test = np.array([scipy.ndimage.zoom(x, ZOOM_FACTOR, order=3) for x in x_test_raw])
# Notebook display: confirm the downscaled 16x16 image size.
x_train[0].shape
(16, 16)
This is a sample MNIST image depicting the number $5$.
px.imshow(x_train[0], width=500, height=500, color_continuous_scale='sunset')
Define a helper function that enables parallelization in the training process.
def load_and_train(x, num_epochs):
    """Build, load, and train a fresh ImageGenerator on a single image.

    Every call constructs its own model instance, so the function is safe
    to dispatch to independent joblib worker processes.

    Args:
        x: 2-D image array to fit.
        num_epochs: number of training epochs to run.

    Returns:
        A tuple ``(output_distribution_history, real_distribution)``.
    """
    generator = ImageGenerator(
        NUM_QUBITS, NUM_LAYERS,
        epoch_sample_size=EPOCH_SAMPLE_SIZE, batch_sample_size=BATCH_SAMPLE_SIZE,
        enable_remapping=True,
    )
    generator.load_image(x, show_figure=False)
    dataset = generator.make_dataset()
    generator.train(dataset, num_epochs, show_progress=False)
    history = generator.get_output_distribution_history()
    target = generator.get_real_distribution()
    return history, target
Generate $36$ random indices to sample the dataset.
# Draw NUM_IMAGES distinct indices into the training set.  Passing the
# integer len(x_train) lets np.random.choice sample from arange(n) directly
# instead of first materializing a Python range into an array.
# NOTE(review): no seed is set, so the sampled indices differ between runs —
# consider np.random.seed(...) for reproducibility.
ind_arr = np.random.choice(len(x_train), size=NUM_IMAGES, replace=False)
ind_arr
array([16538, 41182, 53986, 52774, 30964, 52735, 29028, 29694, 8096,
658, 30286, 17367, 44924, 7405, 7967, 34070, 56154, 39991,
5919, 17316, 56549, 18974, 481, 14617, 34437, 3858, 27050,
26580, 24310, 38478, 52109, 43954, 13892, 3644, 23220, 18097])
Train the models in parallel. This will take a while, so feel free to grab a cup of coffee in the meantime!
# Fan the per-image training jobs out over NUM_CPUS worker processes.
# tqdm_joblib bridges joblib's completion callbacks into the tqdm progress
# bar, so the bar advances as each worker finishes an image.
with tqdm_joblib(tqdm(desc="Training models in parallel", total=NUM_IMAGES)) as progress_bar:
    od_hist_arr = Parallel(n_jobs=NUM_CPUS)(delayed(load_and_train)(x_train[i], NUM_EPOCHS) for i in ind_arr)
Training models in parallel: 100%|██████████| 36/36 [17:52<00:00, 29.78s/it]
Define the cross-entropy between two discrete probability distributions.
# https://stackoverflow.com/questions/47377222/what-is-the-problem-with-my-implementation-of-the-cross-entropy-function
def cross_entropy(predictions, targets, epsilon=1e-12):
    """Compute the cross-entropy between two discrete distributions.

    Works for one-hot targets as well as general probability vectors
    (this notebook passes flattened probability distributions).

    Args:
        predictions: ndarray of predicted probabilities, shape (N, k) or flat.
        targets: ndarray of target probabilities with the same shape.
        epsilon: clipping bound keeping predictions strictly inside (0, 1)
            so the logarithm is finite.

    Returns:
        Scalar cross-entropy, averaged over the first axis.
    """
    predictions = np.clip(predictions, epsilon, 1. - epsilon)
    N = predictions.shape[0]
    # The clip above already bounds predictions away from zero, so the
    # original extra additive constant (+1e-9) inside the log was redundant
    # and has been removed.
    ce = -np.sum(targets * np.log(predictions)) / N
    return ce
Compute and plot the evolution of cross-entropy for each image, as well as the average across all images.
# For each trained image i and each epoch j, compare the model's output
# distribution (od_hist_arr[i][0][j]) against the real distribution
# (od_hist_arr[i][1]) via cross-entropy.
ce_arr = [[cross_entropy(od_hist_arr[i][0][j].flatten(), od_hist_arr[i][1].flatten()) for j in range(NUM_EPOCHS)] for i in range(NUM_IMAGES)]
# Transpose so rows are epochs and columns are images.
df = pd.DataFrame(np.array(ce_arr).T)
# One line per image.
all_fig = px.line(
    df, labels={'index': 'Step', 'value': 'Cross Entropy'},
    title='Evolution of Cross Entropy'
)
all_fig.update_layout(
    showlegend=False
)
# Mean cross-entropy across all images at each epoch.
mean_fig = px.line(
    df.mean(axis=1), labels={'index': 'Step', 'value': 'Cross Entropy'},
    title='Evolution of Mean Cross Entropy'
)
mean_fig.update_layout(
    showlegend=False
)
Observe that the mean cross-entropy decreases rapidly as the number of iterations increases. The value stabilizes at around $1.5 \times 10^{-2}$.
The generated images closely resemble the originals. While some generation artifacts can be observed in the low-probability background, the high-probability foreground is clearly distinguishable.
# Collage of the final generated distribution for each of the 36 models.
NUM_ROWS = 6
NUM_COLS = 6
output_collage = make_subplots(rows=NUM_ROWS, cols=NUM_COLS)
for r in range(NUM_ROWS):
    for c in range(NUM_COLS):
        # od_hist_arr[i][0] is the output-distribution history; [-1] is the
        # distribution after the final training epoch.
        output_collage.add_trace(plot_dist(od_hist_arr[r * NUM_COLS + c][0][-1]).data[0], row=r+1, col=c+1)
output_collage.update_layout(
    width=700, height=700,
)
# Collage of the corresponding real (target) distributions for comparison.
real_collage = make_subplots(rows=NUM_ROWS, cols=NUM_COLS)
for r in range(NUM_ROWS):
    for c in range(NUM_COLS):
        # Bug fix: the traces must be added to real_collage — the original
        # copy-pasted loop appended them to output_collage, leaving this
        # figure empty and doubling the traces in the output collage.
        real_collage.add_trace(plot_dist(od_hist_arr[r * NUM_COLS + c][1]).data[0], row=r+1, col=c+1)
real_collage.update_layout(
    width=700, height=700,
)